{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Graph4NLP Demo: Math Word Problem\n",
    "\n",
    "---\n",
    "\n",
    "In this demo, we will have a closer look at how to apply **Graph2Tree model to the task of math word problem automatically solving**.\n",
    "Math word problem solving aims to infer reasonable equations from given natural language problem descriptions. It is important for exploring automatic solutions to mathematical problems and improving the reasoning ability of neural networks. \n",
    "In this demo, we use the Graph4NLP library to build a GNN-based math word problem (MWP) solving model. \n",
    "\n",
    "The **Graph2Tree** model consists of:\n",
    "\n",
    "- graph construction module (e.g., node embedding based dynamic graph)\n",
    "- graph embedding module (e.g., undirected GraphSage)\n",
    "- predictoin module (e.g., tree decoder with attention and copy mechanisms)\n",
    "\n",
    "As shown in the picture below, we firstly construct graph input from problem description by syntactic parsing (CoreNLP) and then represent the output equation with a hierarchical structure (Node ``N`` stands for non-terminal node).\n",
    "\n",
    "<p align=\"center\">\n",
    "<img src=\"./imgs/g2t.png\" width=\"600\" class=\"center\" alt=\"graph2tree_mwp\"/>\n",
    "    <br/>\n",
    "</p>\n",
    "\n",
    "We will use the built-in Graph2Tree model APIs to build the model, and evaluate it on the Mawps dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Environment setup\n",
    "---\n",
    "1. Create virtual environment\n",
    "```\n",
    "conda create --name graph4nlp python=3.7\n",
    "conda activate graph4nlp\n",
    "```\n",
    "\n",
    "2. Install [graph4nlp](https://github.com/graph4ai/graph4nlp) library\n",
    "- Clone the github repo\n",
    "```\n",
    "git clone -b stable https://github.com/graph4ai/graph4nlp.git\n",
    "cd graph4nlp\n",
    "```\n",
    "- Then run `./configure` (or `./configure.bat` if you are using Windows 10) to config your installation. The configuration program will ask you to specify your CUDA version. If you do not have a GPU, please choose 'cpu'.\n",
    "```\n",
    "./configure\n",
    "```\n",
    "- Finally, install the package\n",
    "```\n",
    "python setup.py install\n",
    "```\n",
    "\n",
    "3. Set up StanfordCoreNLP (for static graph construction only, unnecessary for this demo because preprocessed data is provided)\n",
    "- Download [StanfordCoreNLP](https://stanfordnlp.github.io/CoreNLP/)\n",
    "- Go to the root folder and start the server\n",
    "```\n",
    "java -mx4g -cp \"*\" edu.stanford.nlp.pipeline.StanfordCoreNLPServer -port 9000 -timeout 15000\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-22T13:51:01.626460Z",
     "start_time": "2021-07-22T13:51:01.615750Z"
    }
   },
   "source": [
    "## Load the config file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-27T10:13:55.123979Z",
     "start_time": "2021-07-27T10:13:55.092616Z"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from graph4nlp.pytorch.modules.config import get_basic_args\n",
    "from graph4nlp.pytorch.modules.utils.config_utils import update_values, get_yaml_config\n",
    "\n",
    "def get_args():\n",
    "    config = {'dataset_yaml': \"./config.yaml\",\n",
    "              'learning_rate': 1e-3,\n",
    "              'gpuid': -1,\n",
    "              'seed': 123, \n",
    "              'init_weight': 0.08,\n",
    "              'graph_type': 'static',\n",
    "              'weight_decay': 0, \n",
    "              'max_epochs': 20, \n",
    "              'min_freq': 1,\n",
    "              'grad_clip': 5,\n",
    "              'batch_size': 20,\n",
    "              'share_vocab': True,\n",
    "              'pretrained_word_emb_name': None,\n",
    "              'pretrained_word_emb_url': None,\n",
    "              'pretrained_word_emb_cache_dir': \".vector_cache\",\n",
    "              'checkpoint_save_path': \"./checkpoint_save\",\n",
    "              'beam_size': 4\n",
    "              }\n",
    "    our_args = get_yaml_config(config['dataset_yaml'])\n",
    "    template = get_basic_args(graph_construction_name=our_args[\"graph_construction_name\"],\n",
    "                              graph_embedding_name=our_args[\"graph_embedding_name\"],\n",
    "                              decoder_name=our_args[\"decoder_name\"])\n",
    "    update_values(to_args=template, from_args_list=[our_args, config])\n",
    "    return template\n",
    "\n",
    "# show our config\n",
    "cfg_g2t = get_args()\n",
    "from pprint import pprint\n",
    "pprint(cfg_g2t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-27T10:13:55.452276Z",
     "start_time": "2021-07-27T10:13:55.364294Z"
    }
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "import torch\n",
    "import random\n",
    "import argparse\n",
    "import numpy as np\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from graph4nlp.pytorch.data.data import to_batch\n",
    "from graph4nlp.pytorch.datasets.mawps import MawpsDatasetForTree\n",
    "from graph4nlp.pytorch.modules.graph_construction import DependencyBasedGraphConstruction\n",
    "from graph4nlp.pytorch.modules.graph_embedding import *\n",
    "from graph4nlp.pytorch.models.graph2tree import Graph2Tree\n",
    "from graph4nlp.pytorch.modules.utils.tree_utils import Tree, prepare_oov\n",
    "\n",
    "from utils import convert_to_string, compute_tree_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-27T10:13:55.669860Z",
     "start_time": "2021-07-27T10:13:55.640456Z"
    }
   },
   "outputs": [],
   "source": [
    "class Mawps:\n",
    "    def __init__(self, opt=None):\n",
    "        super(Mawps, self).__init__()\n",
    "        self.opt = opt\n",
    "\n",
    "        seed = self.opt[\"seed\"]\n",
    "        random.seed(seed)\n",
    "        np.random.seed(seed)\n",
    "        torch.manual_seed(seed)\n",
    "\n",
    "        if self.opt[\"gpuid\"] == -1:\n",
    "            self.device = torch.device(\"cpu\")\n",
    "        else:\n",
    "            self.device = torch.device(\"cuda:{}\".format(self.opt[\"gpuid\"]))\n",
    "\n",
    "        self.use_copy = self.opt[\"decoder_args\"][\"rnn_decoder_share\"][\"use_copy\"]\n",
    "        self.use_share_vocab = self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\"share_vocab\"]\n",
    "        self.data_dir = self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\"root_dir\"]\n",
    "\n",
    "        self._build_dataloader()\n",
    "        self._build_model()\n",
    "        self._build_optimizer()\n",
    "\n",
    "    def _build_dataloader(self):\n",
    "        para_dic =  {'root_dir': self.data_dir,\n",
    "                    'word_emb_size': self.opt[\"graph_construction_args\"][\"node_embedding\"][\"input_size\"],\n",
    "                    'topology_builder': DependencyBasedGraphConstruction,\n",
    "                    'topology_subdir': self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\"topology_subdir\"], \n",
    "                    'edge_strategy': self.opt[\"graph_construction_args\"][\"graph_construction_private\"][\"edge_strategy\"],\n",
    "                    'graph_type': 'static',\n",
    "                    'dynamic_graph_type': self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\"graph_type\"], \n",
    "                    'share_vocab': self.use_share_vocab, \n",
    "                    'enc_emb_size': self.opt[\"graph_construction_args\"][\"node_embedding\"][\"input_size\"],\n",
    "                    'dec_emb_size': self.opt[\"decoder_args\"][\"rnn_decoder_share\"][\"input_size\"],\n",
    "                    'dynamic_init_topology_builder': None,\n",
    "                    'min_word_vocab_freq': self.opt[\"min_freq\"],\n",
    "                    'pretrained_word_emb_name': self.opt[\"pretrained_word_emb_name\"],\n",
    "                    'pretrained_word_emb_url': self.opt[\"pretrained_word_emb_url\"], \n",
    "                    'pretrained_word_emb_cache_dir': self.opt[\"pretrained_word_emb_cache_dir\"]\n",
    "                    }\n",
    "\n",
    "        dataset = MawpsDatasetForTree(**para_dic)\n",
    "\n",
    "        self.train_data_loader = DataLoader(dataset.train, batch_size=self.opt[\"batch_size\"], shuffle=True,\n",
    "                                            num_workers=0,\n",
    "                                            collate_fn=dataset.collate_fn)\n",
    "        self.test_data_loader = DataLoader(dataset.test, batch_size=1, shuffle=False, num_workers=0,\n",
    "                                           collate_fn=dataset.collate_fn)\n",
    "        self.valid_data_loader = DataLoader(dataset.val, batch_size=1, shuffle=False, num_workers=0,\n",
    "                                          collate_fn=dataset.collate_fn)\n",
    "        self.vocab_model = dataset.vocab_model\n",
    "        self.src_vocab = self.vocab_model.in_word_vocab\n",
    "        self.tgt_vocab = self.vocab_model.out_word_vocab\n",
    "        self.share_vocab = self.vocab_model.share_vocab if self.use_share_vocab else None\n",
    "\n",
    "    def _build_model(self):\n",
    "        '''For encoder-decoder'''\n",
    "        self.model = Graph2Tree.from_args(self.opt, vocab_model=self.vocab_model)\n",
    "        self.model.init(self.opt[\"init_weight\"])\n",
    "        self.model.to(self.device)\n",
    "\n",
    "    def _build_optimizer(self):\n",
    "        optim_state = {\"learningRate\": self.opt[\"learning_rate\"], \"weight_decay\": self.opt[\"weight_decay\"]}\n",
    "        parameters = [p for p in self.model.parameters() if p.requires_grad]\n",
    "        self.optimizer = optim.Adam(parameters, lr=optim_state['learningRate'], weight_decay=optim_state['weight_decay'])\n",
    "\n",
    "    def train_epoch(self, epoch):\n",
    "        loss_to_print = 0\n",
    "        num_batch = len(self.train_data_loader)\n",
    "        for step, data in tqdm(enumerate(self.train_data_loader), desc=f'Epoch {epoch:02d}', total=len(self.train_data_loader)):\n",
    "            batch_graph, batch_tree_list, batch_original_tree_list = data['graph_data'], data['dec_tree_batch'], data['original_dec_tree_batch']\n",
    "            batch_graph = batch_graph.to(self.device)\n",
    "            self.optimizer.zero_grad()\n",
    "            oov_dict = prepare_oov(\n",
    "                batch_graph, self.src_vocab, self.device) if self.use_copy else None\n",
    "\n",
    "            if self.use_copy:\n",
    "                batch_tree_list_refined = []\n",
    "                for item in batch_original_tree_list:\n",
    "                    tgt_list = oov_dict.get_symbol_idx_for_list(item.strip().split())\n",
    "                    tgt_tree = Tree.convert_to_tree(tgt_list, 0, len(tgt_list), oov_dict)\n",
    "                    batch_tree_list_refined.append(tgt_tree)\n",
    "            loss = self.model(batch_graph, batch_tree_list_refined if self.use_copy else batch_tree_list, oov_dict=oov_dict)\n",
    "            loss.backward()\n",
    "            torch.nn.utils.clip_grad_value_(\n",
    "                self.model.parameters(), self.opt[\"grad_clip\"])\n",
    "            self.optimizer.step()\n",
    "            loss_to_print += loss\n",
    "        return loss_to_print/num_batch\n",
    "\n",
    "    def train(self):\n",
    "        best_acc = -1\n",
    "        best_model = None\n",
    "\n",
    "        print(\"-------------\\nStarting training.\")\n",
    "        for epoch in range(1, self.opt[\"max_epochs\"]+1):\n",
    "            self.model.train()\n",
    "            loss_to_print = self.train_epoch(epoch)\n",
    "            print(\"epochs = {}, train_loss = {:.3f}\".format(epoch, loss_to_print))\n",
    "            if epoch > 15:\n",
    "                val_acc = self.eval(self.model, mode=\"val\")\n",
    "                if val_acc > best_acc:\n",
    "                    best_acc = val_acc\n",
    "                    best_model = self.model\n",
    "        self.eval(best_model, mode=\"test\")\n",
    "        best_model.save_checkpoint(self.opt[\"checkpoint_save_path\"], \"best.pt\")\n",
    "\n",
    "    def eval(self, model, mode=\"val\"):\n",
    "        model.eval()\n",
    "        reference_list = []\n",
    "        candidate_list = []\n",
    "        data_loader = self.test_data_loader if mode == \"test\" else self.valid_data_loader\n",
    "        for data in tqdm(data_loader, desc=\"Eval: \"):\n",
    "            eval_input_graph, batch_tree_list, batch_original_tree_list = data['graph_data'], data['dec_tree_batch'], data['original_dec_tree_batch']\n",
    "            eval_input_graph = eval_input_graph.to(self.device)\n",
    "            oov_dict = prepare_oov(eval_input_graph, self.src_vocab, self.device)\n",
    "\n",
    "            if self.use_copy:\n",
    "                assert len(batch_original_tree_list) == 1\n",
    "                reference = oov_dict.get_symbol_idx_for_list(batch_original_tree_list[0].split())\n",
    "                eval_vocab = oov_dict\n",
    "            else:\n",
    "                assert len(batch_original_tree_list) == 1\n",
    "                reference = model.tgt_vocab.get_symbol_idx_for_list(batch_original_tree_list[0].split())\n",
    "                eval_vocab = self.tgt_vocab\n",
    "            \n",
    "            candidate = model.translate(eval_input_graph,\n",
    "                                        oov_dict=oov_dict,\n",
    "                                        use_beam_search=True,\n",
    "                                        beam_size=self.opt[\"beam_size\"])\n",
    "            \n",
    "            candidate = [int(c) for c in candidate]\n",
    "            num_left_paren = sum(\n",
    "                1 for c in candidate if eval_vocab.idx2symbol[int(c)] == \"(\")\n",
    "            num_right_paren = sum(\n",
    "                1 for c in candidate if eval_vocab.idx2symbol[int(c)] == \")\")\n",
    "            diff = num_left_paren - num_right_paren\n",
    "            if diff > 0:\n",
    "                for i in range(diff):\n",
    "                    candidate.append(\n",
    "                        self.test_data_loader.tgt_vocab.symbol2idx[\")\"])\n",
    "            elif diff < 0:\n",
    "                candidate = candidate[:diff]\n",
    "            ref_str = convert_to_string(\n",
    "                reference, eval_vocab)\n",
    "            cand_str = convert_to_string(\n",
    "                candidate, eval_vocab)\n",
    "\n",
    "            reference_list.append(reference)\n",
    "            candidate_list.append(candidate)\n",
    "        eval_acc = compute_tree_accuracy(\n",
    "            candidate_list, reference_list, eval_vocab)\n",
    "        print(\"{} accuracy = {:.3f}\\n\".format(mode, eval_acc))\n",
    "        return eval_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-27T09:12:35.052975Z",
     "start_time": "2021-07-27T09:12:34.827787Z"
    }
   },
   "outputs": [],
   "source": [
    "!rm ./data/processed/DependencyGraph/*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-27T10:13:58.783006Z",
     "start_time": "2021-07-27T10:13:57.875582Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "a = Mawps(cfg_g2t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-27T10:14:07.766389Z",
     "start_time": "2021-07-27T10:13:59.534775Z"
    }
   },
   "outputs": [],
   "source": [
    "best_acc = a.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "53b113626f4385b7a52334a3b11ec5d9307ad80c73f59f759f44504bc95f0ff2"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
